# LASSO FUNCTIONS ---------------------------------------------------------------


library(hdm)

# Return the names of the non-zero entries of `x$beta`, i.e. the regressors
# the lasso kept. `x` is presumably a fitted hdm::rlasso object; only its
# `beta` element is read here.
get_lasso_coeffs <- function(x) {
  coefs <- x$beta
  names(coefs)[coefs != 0]
}

# Select control variables by lasso double selection.
#
# Builds an expanded set of candidate controls from three sources:
#   - the raw controls;
#   - optionally, each control multiplied by the `interact` column;
#   - optionally, each control's mean over the other rows of its group.
# It then runs hdm::rlasso() twice, once with `y` and once with `t` as the
# response. It returns the union of the controls with non-zero coefficients
# in either fit.
#
# potential_controls: character vector of candidate control columns. Names not
#   found in `df` are dropped (with a warning unless `quiet`).
# y: name of the outcome column.
# t: name of the treatment column; must be numeric.
# df: data frame holding all of the above.
# interact: name of a column to interact every control with, or NULL to skip.
#   Interaction columns are named "<control>__<interact>".
# group_control: if TRUE, add "<control>__group_control" leave-one-out group
#   means.
# quiet: if TRUE, suppress the missing-controls warning and the printed list of
#   candidate controls.
# group_var: grouping column used when `group_control` is TRUE (defaults to
#   "group_id", the column this function always used).
# Returns a character vector of selected control names (possibly empty).
# Errors if required columns are absent, `t` is not numeric, no control is in
# `df`, or any candidate control contains missing values.
get_lasso_controls <- function(potential_controls, y, t, df,
                               interact = "pair_includes_trans",
                               group_control = TRUE, quiet = FALSE,
                               group_var = "group_id") {

  # Fail early with a clear message if a column this function needs is absent,
  # rather than a misleading "must be numeric" or an opaque dplyr/rlasso error.
  required_cols <- c(y, t, interact, if (group_control) group_var)
  absent_cols <- setdiff(required_cols, names(df))
  if (length(absent_cols) > 0) {
    stop("Required variables missing from the dataset: ",
         paste0(absent_cols, collapse = ", "), call. = FALSE)
  }

  if (!is.numeric(df[[t]])) {
    stop("Treatment variable must be numeric", call. = FALSE)
  }

  # Keep only the controls present in the dataset, but warn about the others.
  in_data <- potential_controls %in% names(df)
  if (!any(in_data)) {
    stop("All controls are missing from the dataset", call. = FALSE)
  }

  missing_controls <- potential_controls[!in_data]
  if (length(missing_controls) > 0 && !quiet) {
    warning(paste0("The following controls are missing from the dataset: ",
                   paste0(missing_controls, collapse = ", ")))
  }
  potential_controls <- potential_controls[in_data]

  # Optional interaction terms: each control times the `interact` column.
  controls_interact <- NULL
  if (!is.null(interact)) {
    controls_interact <- paste0(potential_controls, "__", interact)

    df <- df %>%
      mutate(
        across(all_of(potential_controls),
               ~ .x * as.numeric(!!sym(interact)),
               .names = "{.col}__{interact}")
      )
  }

  # Optional leave-one-out group means: (group sum - own value) / (n - 1).
  # NOTE(review): assumes sum_na() (defined elsewhere) is an NA-ignoring sum.
  # Singleton groups divide by zero; the resulting NaN/Inf is not masked here
  # (NaN is caught by the missing-value check below).
  controls_group <- NULL
  if (group_control) {
    controls_group <- paste0(potential_controls, "__group_control")

    df <- df %>%
      group_by(.data[[group_var]]) %>%
      mutate(
        across(all_of(potential_controls),
               list(group_control = ~ (sum_na(.x) - .x) / (n() - 1)),
               .names = "{.col}__group_control")
      ) %>%
      ungroup()
  }

  all_controls <- c(potential_controls, controls_interact, controls_group)

  if (!quiet) {
    print(str_glue("All controls: {paste0(all_controls, collapse = ', ')}"))
  }

  # rlasso needs complete data: report per-column missing counts and stop.
  df_controls <- df %>% select(any_of(all_controls))
  missing_vals <- colSums(is.na(df_controls)) %>%
    enframe() %>%
    filter(value > 0) %>%
    mutate(prop = value / nrow(df_controls))

  if (nrow(missing_vals) > 0) {
    print(missing_vals)
    stop("Controls have missing values", call. = FALSE)
  }

  # Double selection: keep a control if the lasso picks it when predicting
  # either the outcome or the treatment from the full candidate set.
  select_for <- function(response) {
    fit <- hdm::rlasso(
      formula = reformulate(all_controls, response = response),
      data = df
    )
    get_lasso_coeffs(fit)
  }

  unique(c(select_for(y), select_for(t)))
}


# TEST ---------------------------------------------------------------

# Manual smoke test for get_lasso_controls() and feols_custom(lasso = TRUE).
# Set `test` to TRUE to run it interactively. It relies on objects defined
# elsewhere in the project (demo_vars, r2_choices, sum_na, feols_custom).
test <- FALSE

if (test) {

  control_vars <- c(
    demo_vars,
    "phase_2",
    # "high_sdb",
    "b11", "b12", "i1", "i2", "i3", "dq12",
    "normally_receives_delivery",
    "item_diff",
    "r2_reliability_diff",
    "r2_reliability_shown",
    "r2_reliability_benchmark"
    # "pair_includes_female"
  )

  # These columns use across()'s default "<col>_<fn>" names (single
  # underscore). They therefore do not clash with the "<col>__..." columns
  # that get_lasso_controls() builds itself.
  r2_choices_for_controls <- r2_choices %>%

    # Get interaction terms
    mutate(
      across(all_of(control_vars), list(pair_includes_trans = ~ .x * as.numeric(pair_includes_trans)))
    ) %>%

    # Get group level means
    group_by(group_id) %>%
    mutate(
      across(all_of(control_vars), list(group_control = ~ (sum_na(.x) - .x) / (n() - 1)))
    ) %>%
    ungroup %>%
    glimpse

  # The selector is get_lasso_controls(); `lasso_controls` is not a function.
  get_lasso_controls(potential_controls = control_vars,
                     y = "r2_choose_comparator", t = "discussion_full",
                     interact = "pair_includes_trans", group_control = TRUE,
                     df = r2_choices_for_controls)

  feols_custom(
    r2_choose_comparator ~ pair_includes_trans * (discussion_full + stratum_id),
    data = r2_choices_for_controls %>% filter(discuss_type %in% c("control", "discussion_full")),
    fixef = c("stratum_id"),
    cluster = "group_id",
    stratum = "stratum_id",
    coef_omit = "stratum_id|Intercept|video",
    lasso = TRUE,
    lasso_options = list(
      potential_controls = control_vars,
      t = "discussion_full",
      interact = "pair_includes_trans",
      group_control = TRUE
    )
  )

}







# Go from a fitted model to a table of its lasso-selected controls.
#
# Each entry of `model$lasso_controls` is split on "__" into a base variable
# (`value`) and an optional suffix (`type`). The suffix is either
# "group_control" or the name of an interaction variable. Base variables are
# then joined to display labels from `control_vec`.
#
# model: object with a `lasso_controls` character vector. NOTE(review):
#   presumably produced by feols_custom(lasso = TRUE); confirm.
# control_vec: named vector whose values are variable names and whose names
#   are display labels (it is enframe()d and joined on `value`).
# interact_labels: named character vector mapping interaction-variable patterns
#   (regex) to display text. The default reproduces the previously hard-coded
#   mapping.
# Returns a tibble with columns value, type, group_control, interact, name.
# `name` is NA for variables with no label in `control_vec`.
get_lasso_control_df <- function(model, control_vec,
                                 interact_labels = c(pair_includes_trans = "Worker is trans")) {
  model$lasso_controls %>%
    enframe() %>%
    # Plain controls have no "__" suffix; fill = "right" sets their `type` to
    # NA quietly instead of warning once per plain control.
    separate(value, into = c("value", "type"), sep = "__", fill = "right") %>%
    mutate(
      group_control = type %in% "group_control",
      interact = ifelse(!group_control & !is.na(type), type, NA_character_)
    ) %>%
    select(-name) %>%
    left_join(
      control_vec %>% enframe(),
      by = "value"
    ) %>%
    mutate(
      interact = str_replace_all(interact, interact_labels),
      name = ifelse(group_control, paste0("Group-level control: ", name), name),
      name = ifelse(!is.na(interact), paste0(interact, " x ", name), name)
    )
}

# Go from a list of models to one combined table of lasso-selected controls.
#
# models: list of models, each accepted by get_lasso_control_df().
# control_vec: label vector forwarded to get_lasso_control_df(). It defaults to
#   the global `control_vars_with_names`, which the previous version always
#   used even when another vector was passed.
# Returns the per-model tables stacked with bind_rows(). The `column` field
# holds each model's position in `models` as a string. Models with no
# selected controls contribute no rows.
combine_lasso_control_df <- function(models, control_vec = control_vars_with_names) {
  models %>%
    map(get_lasso_control_df, control_vec) %>%
    # seq_along() (not 1:length()) so an empty `models` list does not break.
    set_names(seq_along(.)) %>%
    keep(~ nrow(.x) > 0) %>%
    bind_rows(.id = "column")
}


# Update the list of lasso-control tables stored on disk.
#
# Steps:
#   1. Read the `lasso_controls` list from `file`. Start from an empty list if
#      the file or the object does not exist yet.
#   2. Optionally drop the entries named in `remove`.
#   3. Store the combined control table for `models` under `table_name`.
#   4. Save the list back to `file`.
#
# models: list of models passed to combine_lasso_control_df().
# table_name: name of the entry to create or overwrite.
# remove: optional character vector of entry names to delete first.
# file: path of the .RData file holding `lasso_controls`.
# control_vec: labels for combine_lasso_control_df(). Defaults to the global
#   `control_vars_with_names` used previously.
# Returns the updated list invisibly.
update_lasso_controls <- function(models, table_name, remove = NULL,
                                  file = "data/cleaned/lasso_control.RData",
                                  control_vec = control_vars_with_names) {
  lasso_controls <- list()
  if (file.exists(file)) {
    # Load into a private environment so objects stored in the file cannot
    # overwrite this function's arguments.
    stored <- new.env(parent = emptyenv())
    load(file, envir = stored)
    if (exists("lasso_controls", envir = stored, inherits = FALSE)) {
      lasso_controls <- stored$lasso_controls
    }
  } else {
    message("No existing file at ", file, "; starting a new lasso controls list")
  }

  # Single brackets: `[[<-` with several names would index recursively.
  if (!is.null(remove)) lasso_controls[remove] <- NULL
  lasso_controls[[table_name]] <- combine_lasso_control_df(models, control_vec)
  save(lasso_controls, file = file)
  invisible(lasso_controls)
}